import torch
import numpy as np
import time

import parsee
from torch.autograd import Variable
from utils import subsequent_mask

def log(data, timestamp):
    """Append one line of text to the log file for the given run.

    The target file is ``log/log-<timestamp>.txt``, relative to the current
    working directory. The ``log`` directory is not created here; ``open``
    raises if it does not exist.

    Args:
        data: Text to append. A newline is written after it.
        timestamp: Value interpolated into the file name.
    """
    # The context manager guarantees the handle is closed even if a write
    # raises (for example, when ``data`` is not a string).
    with open(f'log/log-{timestamp}.txt', 'a') as file:
        file.write(data)
        file.write('\n')

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode a sequence by picking the highest-scoring token at every step.

    The source is encoded once with ``model.encode``. Starting from a
    1x1 tensor holding ``start_symbol``, the loop calls ``model.decode`` on
    the tokens produced so far, scores the last position with
    ``model.generator`` and appends the arg-max index.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``generator``.
        src: Source tensor passed to ``model.encode``; also determines the
            tensor type of the generated sequence.
        src_mask: Mask passed alongside ``src`` to ``encode`` and ``decode``.
        max_len: Number of columns in the returned tensor, start symbol
            included (``max_len - 1`` decoding steps are run).
        start_symbol: Index placed in the first position of the output.

    Returns:
        A tensor of shape ``(1, max_len)`` with the same tensor type as
        ``src``. Decoding does not stop early on an end-of-sequence token;
        callers are expected to truncate the result themselves.
    """
    # The result is a tensor of integer indices, so no gradient can flow
    # through it; disabling autograd avoids building a graph for every step
    # when the caller did not already do so.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src)
        for _ in range(max_len - 1):
            # Tensors are passed directly: the ``Variable`` wrapper is
            # deprecated and returns the tensor unchanged.
            tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.decode(memory, src_mask, ys, tgt_mask)
            # Only the last position is needed to choose the next token.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            # Only the first row is used: the output holds a single sequence.
            next_word = next_word[0]
            next_col = torch.ones(1, 1).type_as(src).fill_(next_word)
            ys = torch.cat([ys, next_col], dim=1)
    return ys

def _translate_tokens(data, model, token_ids, device):
    """Greedy-decode one source sentence and return its target-side tokens."""
    src = torch.from_numpy(np.array(token_ids)).long().to(device)
    src = src.unsqueeze(0)  # add a leading dimension of size 1
    # Positions whose id is 0 are masked out (presumably padding -- confirm
    # against the vocabulary construction).
    src_mask = (src != 0).unsqueeze(-2)
    out = greedy_decode(model, src, src_mask,
                        max_len=parsee.max_length,
                        start_symbol=data.cn_word_dict["BOS"])
    translation = []
    # Column 0 holds the start symbol, so reading begins at column 1.
    for j in range(1, out.size(1)):
        sym = data.cn_index_dict[out[0, j].item()]
        if sym == 'EOS':
            break
        translation.append(sym)
    return translation


def evaluate(data, model, presrc=None):
    """Translate ``presrc`` if given, otherwise every sentence in ``data.dev_en``.

    Args:
        data: Object providing ``dev_en`` (sequence of source id sequences),
            ``cn_word_dict`` (token -> index, must contain ``"BOS"``) and
            ``cn_index_dict`` (index -> token).
        model: Model handed to ``greedy_decode``.
        presrc: Optional sequence of source token ids. When ``None`` or
            empty, the development set is decoded instead.

    Returns:
        A list of target tokens, cut at the first ``'EOS'``. In
        development-set mode only the translation of the last sentence is
        returned (an empty list if ``data.dev_en`` is empty); this matches
        the historical return shape that callers rely on.
    """
    # Fall back to the CPU so inference also works on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # An explicit None/length check instead of truthiness: ``if presrc`` raises
    # for NumPy arrays with more than one element.
    has_presrc = presrc is not None and len(presrc) > 0

    # Pre-initialised so an empty development set returns [] rather than
    # hitting an unbound local at the return statement.
    translation = []

    # Inference only: both branches run without autograd bookkeeping.
    with torch.no_grad():
        if has_presrc:
            translation = _translate_tokens(data, model, presrc, device)
        else:
            for sentence in data.dev_en:
                translation = _translate_tokens(data, model, sentence, device)
    return translation
       
            

